{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import numpy as np\n",
    "import torch as tc\n",
    "sys.path.append('..')\n",
    "import distr as ds\n",
    "\n",
    "# Ground-truth configuration of the synthetic (S, Z) pair.\n",
    "shape_s, shape_z = (2, 3), (2, 2)  # event shapes of S and Z\n",
    "shape_bat = (30000,)               # batch (sample) shape\n",
    "mu_S, mu_Z = -1., 1.\n",
    "std_S, std_Z = 1.3, 1.\n",
    "corr_SZ = .7\n",
    "# Plain Python ints, so they behave predictably inside shape tuples and slices.\n",
    "dim_s, dim_z = int(np.prod(shape_s)), int(np.prod(shape_z))\n",
    "\n",
    "# Seed NumPy (data below) and PyTorch (random parameter initialisation in later\n",
    "# cells) so that Restart & Run All reproduces the same numbers.\n",
    "SEED = 0\n",
    "np.random.seed(SEED)\n",
    "tc.manual_seed(SEED)\n",
    "\n",
    "# S: independent Gaussian entries with mean mu_S and std std_S.\n",
    "S_np = mu_S + std_S * np.random.randn(*(shape_bat + shape_s)).astype(np.float32)\n",
    "S_normal_flat = ((S_np - mu_S) / std_S).reshape(shape_bat + (dim_s,))\n",
    "# Standardised driver of Z: the first dim_z entries of standardised S,\n",
    "# zero-padded when Z has more entries than S.\n",
    "if dim_z <= dim_s:\n",
    "    Z_normal_flat = S_normal_flat[..., :dim_z]\n",
    "else:\n",
    "    # The data are still NumPy arrays at this point, so pad with NumPy\n",
    "    # (the torch calls used here previously cannot take ndarrays / NumPy dtypes).\n",
    "    pad = np.zeros(shape_bat + (dim_z - dim_s,), dtype=np.float32)\n",
    "    Z_normal_flat = np.concatenate([S_normal_flat, pad], axis=-1)\n",
    "mu_Z1S = mu_Z + corr_SZ * std_Z * Z_normal_flat.reshape(shape_bat + shape_z)\n",
    "# Convert to a Python float: an np.float64 scalar would promote the float32\n",
    "# samples to float64 under NumPy >= 2 promotion rules.\n",
    "std_Z1S = float(std_Z * np.sqrt(1. - corr_SZ**2))\n",
    "Z_np = mu_Z1S + std_Z1S * np.random.randn(*(shape_bat + shape_z)).astype(np.float32)\n",
    "assert S_np.dtype == np.float32 and Z_np.dtype == np.float32\n",
    "\n",
    "device = tc.device(\"cuda:0\" if tc.cuda.is_available() else \"cpu\")\n",
    "S, Z = tc.from_numpy(S_np).to(device), tc.from_numpy(Z_np).to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning with `ds.Normal`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unconstrained learnable parameters, randomly initialised. The creation order\n",
    "# (mu_s, std_s_param, mu_z, std_z_param, corr_param) is kept so that the random\n",
    "# draws are consumed in the same sequence.\n",
    "mu_s, std_s_param = (tc.randn(shape_s, requires_grad=True, device=device) for _ in range(2))\n",
    "mu_z, std_z_param = (tc.randn(shape_z, requires_grad=True, device=device) for _ in range(2))\n",
    "corr_param = tc.randn(1, requires_grad=True, device=device)\n",
    "\n",
    "def corr_sz():\n",
    "    \"\"\"Map corr_param to 1 - (2*sigmoid(corr_param) - 1)**2, a value in (0, 1].\"\"\"\n",
    "    centred = 2 * tc.sigmoid(corr_param) - 1.\n",
    "    return 1. - centred**2\n",
    "\n",
    "def std_s():\n",
    "    \"\"\"Positive std of s, obtained as exp(std_s_param).\"\"\"\n",
    "    return std_s_param.exp()\n",
    "\n",
    "def std_z():\n",
    "    \"\"\"Positive std of z, obtained as exp(std_z_param).\"\"\"\n",
    "    return std_z_param.exp()\n",
    "\n",
    "def mu_z1s(s):\n",
    "    \"\"\"Conditional mean of z given s under the current parameters.\n",
    "\n",
    "    s is standardised with mu_s and std_s(), its event dims are flattened, and\n",
    "    the first dim_z entries (zero-padded when dim_z > dim_s) are reshaped to\n",
    "    shape_z, scaled by corr_sz() * std_z() and shifted by mu_z.\n",
    "\n",
    "    The leading batch dims are read from the input instead of the global\n",
    "    shape_bat, so the function also works on mini-batches; for the full data\n",
    "    set the result is unchanged.\n",
    "    \"\"\"\n",
    "    s_normal = (s - mu_s) / std_s()\n",
    "    # Everything in front of the len(shape_s) event dims is batch.\n",
    "    bat = tuple(s_normal.shape[:s_normal.dim() - len(shape_s)])\n",
    "    s_normal_flat = s_normal.reshape(bat + (dim_s,))\n",
    "    if dim_z <= dim_s:\n",
    "        z_normal_flat = s_normal_flat[..., :dim_z]\n",
    "    else:\n",
    "        pad = tc.zeros(bat + (dim_z - dim_s,), dtype=s.dtype, device=s.device)\n",
    "        z_normal_flat = tc.cat([s_normal_flat, pad], dim=-1)\n",
    "    return mu_z + corr_sz()*std_z() * z_normal_flat.reshape(bat + shape_z)\n",
    "\n",
    "def std_z1s():\n",
    "    \"\"\"Conditional std of z given s: std_z() scaled by sqrt(1 - corr_sz()**2).\"\"\"\n",
    "    unexplained = 1. - corr_sz()**2\n",
    "    return std_z() * tc.sqrt(unexplained)\n",
    "\n",
    "# Reset the ds.Distr state before declaring the model\n",
    "# (presumably a registry of named distributions -- confirm in distr).\n",
    "ds.Distr.clear()\n",
    "ds.Distr.default_device = device\n",
    "\n",
    "# Marginal of s: mean is a tensor, std is passed as a callable.\n",
    "p_s = ds.Normal(\n",
    "    's',\n",
    "    mean=mu_s,\n",
    "    std=std_s,\n",
    "    shape=shape_s,\n",
    ")\n",
    "# Conditional of z given s: both mean and std are passed as callables.\n",
    "p_z1s = ds.Normal(\n",
    "    'z',\n",
    "    mean=mu_z1s,\n",
    "    std=std_z1s,\n",
    "    shape=shape_z,\n",
    ")\n",
    "# Joint distribution formed as the product of the two factors.\n",
    "p_sz = p_s * p_z1s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.9941, -0.9730, -0.9932],\n",
      "        [-0.9953, -0.9977, -0.9990]], device='cuda:0')\n",
      "tensor([[1.2980, 1.2980, 1.3023],\n",
      "        [1.3075, 1.3032, 1.2968]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Fit the marginal p(s): full-batch gradient descent on the mean negative\n",
    "# log-probability of S, updating only mu_s and std_s_param.\n",
    "opt = tc.optim.SGD([mu_s, std_s_param], lr=1e-3)\n",
    "for step in range(10000):\n",
    "    opt.zero_grad()\n",
    "    logp = p_s.logp({'s': S})\n",
    "    mlogp = -logp.mean()\n",
    "    mlogp.backward()\n",
    "    opt.step()\n",
    "# Learned mean and std of s.\n",
    "print(mu_s.data)\n",
    "print(std_s().data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-1.0678, -1.0178, -1.0556],\n",
      "        [-1.0065, -1.0068, -0.9978]], device='cuda:0')\n",
      "tensor([[1.3031, 1.3017, 1.3065],\n",
      "        [1.3065, 1.3031, 1.2968]], device='cuda:0')\n",
      "tensor([[0.9547, 0.9831],\n",
      "        [0.9609, 0.9943]], device='cuda:0')\n",
      "tensor([[1.0030, 1.0033],\n",
      "        [1.0072, 1.0018]], device='cuda:0')\n",
      "tensor([0.7024], device='cuda:0')\n",
      "tensor([[0.7140, 0.7141],\n",
      "        [0.7169, 0.7131]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Fit the joint p(s, z): full-batch gradient descent on the mean negative\n",
    "# log-probability of (S, Z) over all five parameter tensors.\n",
    "params = [mu_s, std_s_param, mu_z, std_z_param, corr_param]\n",
    "opt = tc.optim.SGD(params, lr=1e-3)\n",
    "for step in range(10000):\n",
    "    opt.zero_grad()\n",
    "    logp = p_sz.logp({'s': S, 'z': Z})\n",
    "    mlogp = -logp.mean()\n",
    "    mlogp.backward()\n",
    "    opt.step()\n",
    "# Learned quantities, one per line: mean/std of s, mean/std of z,\n",
    "# corr_sz() and the conditional std of z given s.\n",
    "fitted = [mu_s.data, std_s().data, mu_z.data, std_z().data, corr_sz().data, std_z1s().data]\n",
    "print(*fitted, sep='\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Learning with `ds.MVNormal`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Unconstrained learnable parameters, randomly initialised in the order\n",
    "# mu_s, std_s_diagpm, std_s_offdiag, mu_z, std_z_diagpm, std_z_offdiag, corr_param.\n",
    "# The *_offdiag tensors carry one extra trailing dim equal to the last event dim,\n",
    "# i.e. one square matrix per leading index.\n",
    "mu_s, std_s_diagpm = (tc.randn(shape_s, requires_grad=True, device=device) for _ in range(2))\n",
    "std_s_offdiag = tc.randn((*shape_s, shape_s[-1]), requires_grad=True, device=device)\n",
    "mu_z, std_z_diagpm = (tc.randn(shape_z, requires_grad=True, device=device) for _ in range(2))\n",
    "std_z_offdiag = tc.randn((*shape_z, shape_z[-1]), requires_grad=True, device=device)\n",
    "corr_param = tc.randn(1, requires_grad=True, device=device)\n",
    "\n",
    "def corr_sz():\n",
    "    \"\"\"Map corr_param to 1 - (2*sigmoid(corr_param) - 1)**2, a value in (0, 1].\"\"\"\n",
    "    centred = 2 * tc.sigmoid(corr_param) - 1.\n",
    "    return 1. - centred**2\n",
    "\n",
    "def std_s():\n",
    "    \"\"\"Lower-triangular factor for s: exp(std_s_diagpm) on the diagonal plus the\n",
    "    strictly-lower triangle of std_s_offdiag.\"\"\"\n",
    "    diagonal = tc.diag_embed(std_s_diagpm.exp())\n",
    "    strictly_lower = tc.tril(std_s_offdiag, diagonal=-1)\n",
    "    return diagonal + strictly_lower\n",
    "\n",
    "def std_z():\n",
    "    \"\"\"Lower-triangular factor for z: exp(std_z_diagpm) on the diagonal plus the\n",
    "    strictly-lower triangle of std_z_offdiag.\"\"\"\n",
    "    diagonal = tc.diag_embed(std_z_diagpm.exp())\n",
    "    strictly_lower = tc.tril(std_z_offdiag, diagonal=-1)\n",
    "    return diagonal + strictly_lower\n",
    "\n",
    "def mu_z1s(s):\n",
    "    \"\"\"Conditional mean of z given s under the current MVNormal parameters.\n",
    "\n",
    "    s - mu_s is whitened by solving against the lower-triangular std_s(), the\n",
    "    event dims are flattened, and the first dim_z entries (zero-padded when\n",
    "    dim_z > dim_s) are reshaped to shape_z, multiplied by std_z(), scaled by\n",
    "    corr_sz() and shifted by mu_z.\n",
    "\n",
    "    The leading batch dims are read from the input instead of the global\n",
    "    shape_bat, so the function also works on mini-batches; for the full data\n",
    "    set the result is unchanged.\n",
    "    \"\"\"\n",
    "    # NOTE(review): tc.triangular_solve is deprecated in newer PyTorch releases in\n",
    "    # favour of tc.linalg.solve_triangular; kept as-is for the pinned environment.\n",
    "    s_normal = tc.triangular_solve((s - mu_s).unsqueeze(-1), std_s(), upper=False)[0].squeeze(-1)\n",
    "    # Everything in front of the len(shape_s) event dims is batch.\n",
    "    bat = tuple(s_normal.shape[:s_normal.dim() - len(shape_s)])\n",
    "    s_normal_flat = s_normal.reshape(bat + (dim_s,))\n",
    "    if dim_z <= dim_s:\n",
    "        z_normal_flat = s_normal_flat[..., :dim_z]\n",
    "    else:\n",
    "        pad = tc.zeros(bat + (dim_z - dim_s,), dtype=s.dtype, device=s.device)\n",
    "        z_normal_flat = tc.cat([s_normal_flat, pad], dim=-1)\n",
    "    # Trailing singleton dim turns each row into a column vector for the matmul.\n",
    "    return mu_z + corr_sz() * (std_z() @ z_normal_flat.reshape(bat + shape_z + (1,))).squeeze(-1)\n",
    "\n",
    "def std_z1s():\n",
    "    \"\"\"Conditional lower-triangular factor of z given s: std_z() scaled by\n",
    "    sqrt(1 - corr_sz()**2).\"\"\"\n",
    "    unexplained = 1. - corr_sz()**2\n",
    "    return std_z() * tc.sqrt(unexplained)\n",
    "\n",
    "# Reset the ds.Distr state before declaring the model\n",
    "# (presumably a registry of named distributions -- confirm in distr).\n",
    "ds.Distr.clear()\n",
    "ds.Distr.default_device = device\n",
    "\n",
    "# Marginal of s: mean is a tensor, the triangular factor is passed as a callable.\n",
    "p_s = ds.MVNormal(\n",
    "    's',\n",
    "    mean=mu_s,\n",
    "    std_tril=std_s,\n",
    "    shape=shape_s,\n",
    ")\n",
    "# Conditional of z given s: both mean and triangular factor are passed as callables.\n",
    "p_z1s = ds.MVNormal(\n",
    "    'z',\n",
    "    mean=mu_z1s,\n",
    "    std_tril=std_z1s,\n",
    "    shape=shape_z,\n",
    ")\n",
    "# Joint distribution formed as the product of the two factors.\n",
    "p_sz = p_s * p_z1s"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.9912, -0.9900, -0.9952],\n",
      "        [-0.9905, -0.9530, -0.9386]], device='cuda:0')\n",
      "tensor([[[ 1.2981,  0.0000,  0.0000],\n",
      "         [-0.0110,  1.2976,  0.0000],\n",
      "         [ 0.0114, -0.0080,  1.3025]],\n",
      "\n",
      "        [[ 1.3076,  0.0000,  0.0000],\n",
      "         [-0.0060,  1.3123,  0.0000],\n",
      "         [ 0.0095,  0.0899,  1.3067]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Fit the marginal p(s): full-batch gradient descent on the mean negative\n",
    "# log-probability of S, updating the mean and both parts of the triangular factor.\n",
    "params = [mu_s, std_s_diagpm, std_s_offdiag]\n",
    "opt = tc.optim.SGD(params, lr=1e-3)\n",
    "for step in range(10000):\n",
    "    opt.zero_grad()\n",
    "    logp = p_s.logp({'s': S})\n",
    "    mlogp = -logp.mean()\n",
    "    mlogp.backward()\n",
    "    opt.step()\n",
    "# Learned mean and lower-triangular factor of s.\n",
    "print(mu_s.data)\n",
    "print(std_s().data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[-0.9969, -1.0019, -1.0026],\n",
      "        [-1.0011, -1.0066, -0.9976]], device='cuda:0')\n",
      "tensor([[[ 1.2988e+00,  0.0000e+00,  0.0000e+00],\n",
      "         [-3.6102e-04,  1.2994e+00,  0.0000e+00],\n",
      "         [ 2.0667e-03,  3.0024e-03,  1.3029e+00]],\n",
      "\n",
      "        [[ 1.3043e+00,  0.0000e+00,  0.0000e+00],\n",
      "         [-1.8239e-02,  1.3030e+00,  0.0000e+00],\n",
      "         [-1.4886e-03, -9.8890e-03,  1.2968e+00]]], device='cuda:0')\n",
      "tensor([[1.0038, 0.9942],\n",
      "        [0.9977, 0.9980]], device='cuda:0')\n",
      "tensor([[[1.0001e+00, 0.0000e+00],\n",
      "         [4.5207e-03, 1.0015e+00]],\n",
      "\n",
      "        [[1.0047e+00, 0.0000e+00],\n",
      "         [2.6764e-04, 1.0001e+00]]], device='cuda:0')\n",
      "tensor([0.7009], device='cuda:0')\n",
      "tensor([[[7.1335e-01, 0.0000e+00],\n",
      "         [3.2245e-03, 7.1434e-01]],\n",
      "\n",
      "        [[7.1660e-01, 0.0000e+00],\n",
      "         [1.9090e-04, 7.1335e-01]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "# Fit the joint p(s, z): full-batch gradient descent on the mean negative\n",
    "# log-probability of (S, Z) over all seven parameter tensors.\n",
    "params = [mu_s, std_s_diagpm, std_s_offdiag, mu_z, std_z_diagpm, std_z_offdiag, corr_param]\n",
    "opt = tc.optim.SGD(params, lr=1e-3)\n",
    "for step in range(10000):\n",
    "    opt.zero_grad()\n",
    "    logp = p_sz.logp({'s': S, 'z': Z})\n",
    "    mlogp = -logp.mean()\n",
    "    mlogp.backward()\n",
    "    opt.step()\n",
    "# Learned quantities, one per line: mean/factor of s, mean/factor of z,\n",
    "# corr_sz() and the conditional factor of z given s.\n",
    "fitted = [mu_s.data, std_s().data, mu_z.data, std_z().data, corr_sz().data, std_z1s().data]\n",
    "print(*fitted, sep='\\n')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
